import torch
import torch.nn as nn
import time

# Multi-GPU inference smoke run: one linear layer is wrapped in
# nn.DataParallel over every visible CUDA device and fed 20 random batches.

# Check the device count BEFORE anything touches CUDA. torch.cuda.device_count()
# returns 0 on a machine without usable GPUs, so this yields a clear error
# instead of a low-level CUDA failure from `.to(device)`. An explicit raise is
# used rather than `assert` because asserts are stripped under `python -O`.
gpu_count = torch.cuda.device_count()
if gpu_count < 2:
    raise RuntimeError(
        f"this script requires at least 2 CUDA devices, found {gpu_count}"
    )

device = torch.device("cuda")

# The wrapped module's parameters live on the current CUDA device; DataParallel
# is given the full list of device indices to run on.
lin = nn.Linear(1000, 10).to(device)
model = nn.DataParallel(lin, device_ids=list(range(gpu_count)))

# Inference only: no autograd graph is needed, so gradient tracking is
# disabled once around the whole loop.
with torch.no_grad():
    for _ in range(20):
        # Random batch of 800 rows x 1000 features, created on the host and
        # then copied to the GPU (same as the original `.cuda()` call).
        ipt_ids = torch.randn([800, 1000]).to(device)
        # Linear maps 1000 -> 10 features; softmax over dim=1 normalizes those
        # 10 outputs per row.
        rst = model(ipt_ids).softmax(dim=1)

# NOTE(review): the purpose of this 2 s pause is not evident from the script
# itself (possibly to keep the process alive briefly for observation or to
# reproduce a shutdown-time issue) -- TODO confirm before removing.
time.sleep(2)